{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# word2vec\n",
    "\n",
    "This notebook is equivalent to `demo-word.sh`, `demo-analogy.sh`, `demo-phrases.sh` and `demo-classes.sh` from the Google examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download some data, for example: [http://mattmahoney.net/dc/text8.zip](http://mattmahoney.net/dc/text8.zip)\n",
    "\n",
    "You could use `make test-data` from the root of the repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import word2vec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run `word2phrase` to join words that frequently appear together into single phrase tokens, e.g. \"Los Angeles\" becomes \"Los_Angeles\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running command: word2phrase -train ../data/text8 -output ../data/text8-phrases -min-count 5 -threshold 100 -debug 2\n",
      "Starting training using file ../data/text8\n",
      "Words processed: 17000K     Vocab size: 4399K  \n",
      "Vocab size (unigrams + bigrams): 2419827\n",
      "Words in train file: 17005206\n",
      "Words written: 17000K\r"
     ]
    }
   ],
   "source": [
    "word2vec.word2phrase('../data/text8', '../data/text8-phrases', verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This created a `text8-phrases` file that we can use as a better input for `word2vec`.\n",
    "\n",
    "Note that you could easily skip this previous step and use the text data as input for `word2vec` directly.\n",
    "\n",
    "Now train the word2vec model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running command: word2vec -train ../data/text8-phrases -output ../data/text8.bin -size 100 -window 5 -sample 1e-3 -hs 0 -negative 5 -threads 12 -iter 5 -min-count 5 -alpha 0.025 -debug 2 -binary 1 -cbow 1\n",
      "Starting training using file ../data/text8-phrases\n",
      "Vocab size: 98331\n",
      "Words in train file: 15857306\n",
      "Alpha: 0.000002  Progress: 100.03%  Words/thread/sec: 372.45k  4844  Progress: 0.64%  Words/thread/sec: 355.24k   Progress: 0.86%  Words/thread/sec: 338.88k  ss: 1.08%  Words/thread/sec: 352.03k  9%  Words/thread/sec: 362.78k  ds/thread/sec: 352.84k  Progress: 2.28%  Words/thread/sec: 359.77k  05%  Words/thread/sec: 361.45k   ress: 4.60%  Words/thread/sec: 366.96k  gress: 5.03%  Words/thread/sec: 366.26k  : 363.87k  ogress: 6.36%  Words/thread/sec: 367.88k  : 367.25k  ogress: 8.12%  Words/thread/sec: 369.18k  rds/thread/sec: 367.96k   Words/thread/sec: 368.16k  thread/sec: 368.42k  c: 368.42k  022424  Progress: 10.32%  Words/thread/sec: 369.10k  s: 11.70%  Words/thread/sec: 368.98k  0.021998  Progress: 12.02%  Words/thread/sec: 369.54k  ec: 368.99k  .43%  Words/thread/sec: 369.35k   0.021682  Progress: 13.29%  Words/thread/sec: 369.87k  c: 369.82k  ess: 14.85%  Words/thread/sec: 370.06k  ha: 0.021165  Progress: 15.36%  Words/thread/sec: 370.47k    14  Progress: 15.56%  Words/thread/sec: 369.39k  6.17%  Words/thread/sec: 370.14k  89%  Words/thread/sec: 370.19k  read/sec: 371.18k   Words/thread/sec: 369.78k  019968  Progress: 20.14%  Words/thread/sec: 370.77k  s/thread/sec: 370.14k   0.019533  Progress: 21.88%  Words/thread/sec: 370.08k  ead/sec: 369.16k  019294  Progress: 22.84%  Words/thread/sec: 369.53k  019273  Progress: 22.92%  Words/thread/sec: 370.02k  : 22.96%  Words/thread/sec: 370.35k  5  Progress: 23.63%  Words/thread/sec: 370.45k  48%  Words/thread/sec: 369.77k    70.93k  0.018416  Progress: 26.35%  Words/thread/sec: 370.96k  lpha: 0.018409  Progress: 26.38%  Words/thread/sec: 371.22k  .018341  Progress: 26.65%  Words/thread/sec: 370.58k  ds/thread/sec: 370.24k  : 0.017909  Progress: 28.38%  Words/thread/sec: 370.02k  Words/thread/sec: 370.48k  017573  Progress: 29.72%  Words/thread/sec: 370.60k  ogress: 29.90%  Words/thread/sec: 370.63k   Words/thread/sec: 370.62k  ds/thread/sec: 369.86k  : 0.017097  Progress: 31.63%  Words/thread/sec: 369.79k  
Words/thread/sec: 370.01k  ec: 370.65k  ha: 0.016674  Progress: 33.32%  Words/thread/sec: 369.90k  03k  d/sec: 370.44k  6k  6136  Progress: 35.47%  Words/thread/sec: 370.04k  ogress: 36.19%  Words/thread/sec: 370.02k  ords/thread/sec: 369.84k   0.015774  Progress: 36.92%  Words/thread/sec: 370.08k  : 370.31k  rogress: 38.29%  Words/thread/sec: 370.34k  %  Words/thread/sec: 370.15k  lpha: 0.015024  Progress: 39.92%  Words/thread/sec: 370.05k  thread/sec: 369.99k  83%  Words/thread/sec: 370.52k    42.58%  Words/thread/sec: 370.53k  57k  %  Words/thread/sec: 370.34k  ss: 44.31%  Words/thread/sec: 370.63k  .013725  Progress: 45.11%  Words/thread/sec: 370.72k  5.83%  Words/thread/sec: 370.75k  3k  pha: 0.013064  Progress: 47.76%  Words/thread/sec: 371.11k    Progress: 47.77%  Words/thread/sec: 370.85k   48.28%  Words/thread/sec: 370.83k  9.07%  Words/thread/sec: 370.85k  5k  : 50.81%  Words/thread/sec: 370.61k  012175  Progress: 51.31%  Words/thread/sec: 370.98k  ha: 0.012117  Progress: 51.54%  Words/thread/sec: 370.46k  ords/thread/sec: 370.71k  ha: 0.011760  Progress: 52.97%  Words/thread/sec: 370.79k    Words/thread/sec: 370.72k  %  Words/thread/sec: 371.20k  /thread/sec: 370.83k  %  Words/thread/sec: 371.00k  s: 55.26%  Words/thread/sec: 370.64k  read/sec: 370.97k  thread/sec: 371.02k  .010570  Progress: 57.73%  Words/thread/sec: 370.82k  ogress: 58.32%  Words/thread/sec: 370.90k  hread/sec: 371.24k  ogress: 59.05%  Words/thread/sec: 371.09k  ec: 371.19k   Progress: 60.84%  Words/thread/sec: 370.92k  d/sec: 371.39k  rds/thread/sec: 371.24k    Words/thread/sec: 371.33k  ogress: 62.38%  Words/thread/sec: 371.44k  ec: 370.98k   Progress: 64.12%  Words/thread/sec: 371.04k  d/sec: 371.08k  gress: 65.46%  Words/thread/sec: 371.32k  c: 371.00k   65.68%  Words/thread/sec: 370.96k  ec: 371.08k  k  1.19k  3%  Words/thread/sec: 371.43k  : 371.02k  1.10k  %  Words/thread/sec: 371.40k  s: 70.40%  Words/thread/sec: 371.22k    Words/thread/sec: 371.19k  2k  8%  Words/thread/sec: 
371.23k  c: 371.25k  06652  Progress: 73.41%  Words/thread/sec: 371.21k  .006482  Progress: 74.09%  Words/thread/sec: 371.17k  ds/thread/sec: 371.26k   Progress: 75.80%  Words/thread/sec: 371.40k  hread/sec: 371.15k  d/sec: 371.37k  ec: 371.43k   Progress: 77.12%  Words/thread/sec: 371.44k  d/sec: 371.34k  84  Progress: 78.88%  Words/thread/sec: 371.41k  rogress: 79.82%  Words/thread/sec: 371.38k    Words/thread/sec: 371.33k  76%  Words/thread/sec: 371.27k    .24%  Words/thread/sec: 371.39k   1.47k  /sec: 371.50k  677  Progress: 85.31%  Words/thread/sec: 371.37k   371.54k  ogress: 86.97%  Words/thread/sec: 371.55k  ec: 371.58k  02842  Progress: 88.65%  Words/thread/sec: 371.35k  2644  Progress: 89.44%  Words/thread/sec: 371.48k  ress: 90.16%  Words/thread/sec: 371.39k  : 371.47k  rogress: 91.94%  Words/thread/sec: 371.30k    a: 0.001601  Progress: 93.61%  Words/thread/sec: 371.37k  371.37k  gress: 95.18%  Words/thread/sec: 371.42k  c: 371.55k  Progress: 96.91%  Words/thread/sec: 371.50k  /sec: 371.56k  6  Progress: 98.67%  Words/thread/sec: 371.74k  ead/sec: 372.05k  "
     ]
    }
   ],
   "source": [
    "word2vec.word2vec('../data/text8-phrases', '../data/text8.bin', size=100, binary=True, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That created a `text8.bin` file containing the word vectors in a binary format."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate the clusters of the vectors based on the trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running command: word2vec -train ../data/text8 -output ../data/text8-clusters.txt -size 100 -window 5 -sample 1e-3 -hs 0 -negative 5 -threads 12 -iter 5 -min-count 5 -alpha 0.025 -debug 2 -binary 0 -cbow 1 -classes 100\n",
      "Starting training using file ../data/text8\n",
      "Vocab size: 71291\n",
      "Words in train file: 16718843\n",
      "Alpha: 0.000002  Progress: 100.04%  Words/thread/sec: 384.20k  024813  Progress: 0.76%  Words/thread/sec: 371.85k  3  Progress: 0.96%  Words/thread/sec: 351.62k  gress: 1.16%  Words/thread/sec: 363.80k  ead/sec: 365.33k  352  Progress: 2.60%  Words/thread/sec: 374.29k  ead/sec: 376.16k  s/thread/sec: 379.89k   377.57k  rogress: 4.46%  Words/thread/sec: 379.25k  sec: 380.91k  3527  Progress: 5.91%  Words/thread/sec: 380.41k  hread/sec: 379.70k  25  Progress: 7.51%  Words/thread/sec: 379.69k    ds/thread/sec: 381.44k  40k  1.61k  ss: 9.60%  Words/thread/sec: 381.09k  ress: 10.42%  Words/thread/sec: 381.32k  ead/sec: 381.57k  ords/thread/sec: 381.38k  ec: 381.83k  .67k  ress: 13.45%  Words/thread/sec: 382.04k  .021475  Progress: 14.11%  Words/thread/sec: 380.82k  : 14.33%  Words/thread/sec: 382.34k  1411  Progress: 14.37%  Words/thread/sec: 382.57k  s/thread/sec: 382.69k  82.08k  ress: 16.49%  Words/thread/sec: 382.92k  : 382.62k  0538  Progress: 17.86%  Words/thread/sec: 382.36k  Progress: 18.12%  Words/thread/sec: 381.75k  /sec: 382.32k  9  Progress: 19.79%  Words/thread/sec: 382.95k  ead/sec: 382.90k  21.04%  Words/thread/sec: 383.81k  pha: 0.019590  Progress: 21.65%  Words/thread/sec: 383.21k  %  Words/thread/sec: 382.76k  .08%  Words/thread/sec: 383.30k  gress: 24.29%  Words/thread/sec: 382.82k  8880  Progress: 24.49%  Words/thread/sec: 383.45k  thread/sec: 383.67k  .018475  Progress: 26.11%  Words/thread/sec: 382.76k  ds/thread/sec: 383.30k  4.03k  8069  Progress: 27.74%  Words/thread/sec: 382.81k  rds/thread/sec: 382.97k  ec: 383.34k   Progress: 28.95%  Words/thread/sec: 382.88k  d/sec: 383.14k  62  Progress: 30.56%  Words/thread/sec: 382.91k    ogress: 31.24%  Words/thread/sec: 382.69k  2.99k  ess: 31.83%  Words/thread/sec: 382.53k  82.89k  ress: 33.39%  Words/thread/sec: 382.63k  1k  407  Progress: 34.38%  Words/thread/sec: 382.67k  .69%  Words/thread/sec: 382.63k  pha: 0.016116  Progress: 35.55%  Words/thread/sec: 383.22k  s: 36.36%  Words/thread/sec: 
382.72k  %  Words/thread/sec: 383.31k  /thread/sec: 382.56k  ess: 37.61%  Words/thread/sec: 383.28k  hread/sec: 383.29k  .30%  Words/thread/sec: 382.98k  k  ress: 40.90%  Words/thread/sec: 383.65k  .42k  : 383.47k  .41%  Words/thread/sec: 383.40k  k   44.03%  Words/thread/sec: 383.36k  887  Progress: 44.46%  Words/thread/sec: 383.25k  .49%  Words/thread/sec: 383.46k  .45%  Words/thread/sec: 383.37k  k   47.07%  Words/thread/sec: 383.21k  ords/thread/sec: 383.03k  s: 47.88%  Words/thread/sec: 383.47k  c: 383.09k  d/sec: 383.02k  k  09%  Words/thread/sec: 382.87k  Words/thread/sec: 382.95k  ords/thread/sec: 383.11k  rogress: 51.18%  Words/thread/sec: 383.41k  132  Progress: 51.48%  Words/thread/sec: 383.40k   Words/thread/sec: 382.84k  Words/thread/sec: 382.92k     Words/thread/sec: 383.21k  pha: 0.011385  Progress: 54.47%  Words/thread/sec: 383.36k  c: 383.00k  s/thread/sec: 382.88k  ha: 0.010807  Progress: 56.79%  Words/thread/sec: 382.85k  ogress: 57.45%  Words/thread/sec: 382.99k  rds/thread/sec: 382.64k  3%  Words/thread/sec: 382.83k  k  pha: 0.010227  Progress: 59.11%  Words/thread/sec: 382.73k  ress: 59.79%  Words/thread/sec: 383.09k  s/thread/sec: 382.97k  2  Progress: 61.12%  Words/thread/sec: 383.04k  : 0.009671  Progress: 61.33%  Words/thread/sec: 383.14k  Words/thread/sec: 383.02k  pha: 0.009266  Progress: 62.95%  Words/thread/sec: 382.94k  %  Words/thread/sec: 383.05k  0.008910  Progress: 64.37%  Words/thread/sec: 383.36k   5.38%  Words/thread/sec: 382.89k  7k  : 67.00%  Words/thread/sec: 382.98k   0.008093  Progress: 67.64%  Words/thread/sec: 383.03k    Progress: 68.00%  Words/thread/sec: 383.03k   8.42%  Words/thread/sec: 382.95k  8k  : 70.04%  Words/thread/sec: 382.92k  d/sec: 383.11k  read/sec: 383.16k   Words/thread/sec: 382.91k  ad/sec: 383.02k  0k  : 73.07%  Words/thread/sec: 382.96k  lpha: 0.006537  Progress: 73.86%  Words/thread/sec: 382.85k   74.28%  Words/thread/sec: 382.99k  ress: 74.69%  Words/thread/sec: 382.81k  k  ha: 0.005986  Progress: 
76.07%  Words/thread/sec: 382.91k  lpha: 0.005823  Progress: 76.72%  Words/thread/sec: 383.03k  3%  Words/thread/sec: 383.32k  ess: 77.73%  Words/thread/sec: 382.86k  thread/sec: 383.21k  : 383.00k  rogress: 79.38%  Words/thread/sec: 383.06k  sec: 382.99k  : 383.36k  ess: 81.12%  Words/thread/sec: 383.06k  : 0.004689  Progress: 81.26%  Words/thread/sec: 383.00k  Words/thread/sec: 383.08k  pha: 0.004276  Progress: 82.91%  Words/thread/sec: 383.15k  %  Words/thread/sec: 383.22k  1%  Words/thread/sec: 383.10k  k  gress: 85.17%  Words/thread/sec: 383.27k   ds/thread/sec: 383.03k  0.003162  Progress: 87.36%  Words/thread/sec: 383.19k  hread/sec: 383.10k  read/sec: 383.10k  02403  Progress: 90.40%  Words/thread/sec: 383.09k  /thread/sec: 383.33k  0.001999  Progress: 92.02%  Words/thread/sec: 383.09k  rds/thread/sec: 383.15k  a: 0.001594  Progress: 93.64%  Words/thread/sec: 383.12k   Words/thread/sec: 383.16k  lpha: 0.001189  Progress: 95.26%  Words/thread/sec: 383.26k  383.28k   7.68%  Words/thread/sec: 383.32k  9k  : 99.38%  Words/thread/sec: 383.64k  "
     ]
    }
   ],
   "source": [
    "word2vec.word2clusters('../data/text8', '../data/text8-clusters.txt', 100, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That created a `text8-clusters.txt` with the cluster for every word in the vocabulary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import word2vec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import the `word2vec` binary file created above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "model = word2vec.load('../data/text8.bin')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can take a look at the vocabulary as a numpy array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['</s>', 'the', 'of', ..., 'dakotas', 'nias', 'burlesques'],\n",
       "      dtype='<U78')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Or take a look at the whole matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(98331, 100)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.vectors.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.14333282,  0.15825513, -0.13715845, ...,  0.05456942,\n",
       "         0.10955409,  0.00693387],\n",
       "       [ 0.09608812,  0.04207754, -0.03806892, ..., -0.03362655,\n",
       "         0.08464055, -0.09016852],\n",
       "       [ 0.12761782, -0.04113456, -0.125634  , ..., -0.05283457,\n",
       "         0.13063692,  0.0944253 ],\n",
       "       ...,\n",
       "       [ 0.03679391,  0.07516006, -0.0347415 , ...,  0.05325394,\n",
       "         0.0078049 ,  0.02930316],\n",
       "       [ 0.01181444, -0.00175774,  0.08741257, ...,  0.09732709,\n",
       "        -0.16259262, -0.05862194],\n",
       "       [-0.13616219,  0.00300416, -0.05052309, ...,  0.07784981,\n",
       "        -0.13167247, -0.09754379]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.vectors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can retrieve the vector of individual words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100,)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model['dog'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.09475172,  0.13367875,  0.09377556,  0.17996974,  0.07507402,\n",
       "        0.01064182,  0.14914408,  0.06769233,  0.00867516, -0.01029446])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model['dog'][:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can calculate the distance between two or more (all combinations) words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('dog', 'cat', 0.871114802703574),\n",
       " ('dog', 'fish', 0.593594274883308),\n",
       " ('cat', 'fish', 0.619406762291862)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.distance(\"dog\", \"cat\", \"fish\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Similarity\n",
    "\n",
    "We can do simple queries to retrieve words similar to \"dog\" based on cosine similarity:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([ 2437,  5478,  7593, 10309,  2428,  9963, 10230,  4812,  2391,\n",
       "         3964]),\n",
       " array([0.8711148 , 0.83590669, 0.7858697 , 0.77791787, 0.761248  ,\n",
       "        0.75992145, 0.75895443, 0.75799265, 0.75621061, 0.75438873]))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "indexes, metrics = model.similar(\"dog\")\n",
    "indexes, metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This returned a tuple with 2 items:\n",
    "1. numpy array with the indexes of the similar words in the vocabulary\n",
    "2. numpy array with cosine similarity to each word"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can get the words for those indexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['cat', 'cow', 'goat', 'rat', 'bear', 'rabbit', 'pig', 'wolf',\n",
       "       'girl', 'dogs'], dtype='<U78')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.vocab[indexes]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There is a helper function to create a combined response as a numpy [record array](https://numpy.org/doc/stable/user/basics.rec.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "rec.array([('cat', 0.8711148 ), ('cow', 0.83590669), ('goat', 0.7858697 ),\n",
       "           ('rat', 0.77791787), ('bear', 0.761248  ),\n",
       "           ('rabbit', 0.75992145), ('pig', 0.75895443),\n",
       "           ('wolf', 0.75799265), ('girl', 0.75621061),\n",
       "           ('dogs', 0.75438873)],\n",
       "          dtype=[('word', '<U78'), ('metric', '<f8')])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.generate_response(indexes, metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is easy to convert that numpy array into a pure Python response:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('cat', 0.8711148027035741),\n",
       " ('cow', 0.835906689576186),\n",
       " ('goat', 0.785869702381845),\n",
       " ('rat', 0.7779178650245034),\n",
       " ('bear', 0.761248001192216),\n",
       " ('rabbit', 0.7599214521649291),\n",
       " ('pig', 0.7589544283401916),\n",
       " ('wolf', 0.7579926492523692),\n",
       " ('girl', 0.7562106136633342),\n",
       " ('dogs', 0.7543887265825944)]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.generate_response(indexes, metrics).tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Phrases"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we trained the model with the output of `word2phrase` we can ask for similarity of \"phrases\", basically combined words such as \"Los Angeles\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('san_francisco', 0.8972880096916375),\n",
       " ('san_diego', 0.8764641849445813),\n",
       " ('miami', 0.839973256930157),\n",
       " ('seattle', 0.8341662072927412),\n",
       " ('las_vegas', 0.8333799027822566),\n",
       " ('detroit', 0.8210299188979637),\n",
       " ('chicago', 0.818606214969616),\n",
       " ('cincinnati', 0.8172210251957758),\n",
       " ('atlanta', 0.8101558040323271),\n",
       " ('st_louis', 0.8101060091264682)]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "indexes, metrics = model.similar('los_angeles')\n",
    "model.generate_response(indexes, metrics).tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analogies"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's possible to do more complex queries like analogies such as: `king - man + woman = queen` \n",
    "This method returns the same as `similar`: the indexes of the words in the vocab and the metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([1087, 7523, 1145, 8419, 1335,  648, 6768, 1827, 3141,  344]),\n",
       " array([0.30186443, 0.27870405, 0.27706988, 0.27548467, 0.27114285,\n",
       "        0.27089344, 0.26896827, 0.26865763, 0.26679681, 0.26623266]))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "indexes, metrics = model.analogy(pos=['king', 'woman'], neg=['man'])\n",
    "indexes, metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('queen', 0.3018644250758753),\n",
       " ('empress', 0.2787040539374566),\n",
       " ('prince', 0.2770698840727954),\n",
       " ('aragon', 0.2754846673265303),\n",
       " ('wife', 0.27114284971472047),\n",
       " ('emperor', 0.2708934379826948),\n",
       " ('regent', 0.2689682685888384),\n",
       " ('throne', 0.2686576321715817),\n",
       " ('monarch', 0.26679681320875803),\n",
       " ('son', 0.2662326633886227)]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.generate_response(indexes, metrics).tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "clusters = word2vec.load_clusters('../data/text8-clusters.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can get the cluster number for individual words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['</s>', 'the', 'of', ..., 'bredon', 'skirting', 'santamaria'],\n",
       "      dtype='<U29')"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clusters.vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can get all the words grouped in a specific cluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(209,)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clusters.get_words_on_cluster(90).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['along', 'together', 'associated', 'relations', 'relationship',\n",
       "       'deal', 'combined', 'contact', 'connection', 'bond'], dtype='<U29')"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clusters.get_words_on_cluster(90)[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can add the clusters to the word2vec model and generate a response that includes the clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "model.clusters = clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "indexes, metrics = model.analogy(pos=[\"paris\", \"germany\"], neg=[\"france\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('berlin', 0.3222541117490837, 15),\n",
       " ('munich', 0.28194423437060834, 15),\n",
       " ('vienna', 0.27364676745734806, 12),\n",
       " ('moscow', 0.26758512046265515, 74),\n",
       " ('leipzig', 0.26465489144791404, 8),\n",
       " ('st_petersburg', 0.2621400985097239, 61),\n",
       " ('dresden', 0.2565623490594948, 71),\n",
       " ('prague', 0.24847315163847422, 72),\n",
       " ('bonn', 0.2449265540079843, 8),\n",
       " ('z_rich', 0.2402333559866084, 80)]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.generate_response(indexes, metrics).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
